# -*- coding: utf-8 -*-
"""
Created on Sat Feb 16 14:38:02 2019

@author: Phillip Wolf from Emory University
@Helpful comments from Lee Pang from Amazon
"""

import time
import logging
import argparse

import numpy as np
#from scipy import spatial
from numba import jit
from dask import compute, delayed
import dask.threaded
import dask.multiprocessing

# Configure the root logger so that every record emitted by this script is
# written through a stream handler with a timestamped format.
logger = logging.getLogger()

ch = logging.StreamHandler()
ch.setFormatter(
    logging.Formatter(fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
)

logger.addHandler(ch)
logger.setLevel(logging.DEBUG)  # emit everything from DEBUG upwards

# Command-line interface of the script.
parser = argparse.ArgumentParser()

# Opt-in flag: absent -> False, present -> True.
parser.add_argument(
    '--parallelized',
    dest='parallelized',
    action='store_true',
    default=False,
    help="""
        Run with parallelization (via dask) enabled.
    """,
)

# Name of the scheduler; main() forwards it to dask's compute() when
# --parallelized is given.
parser.add_argument(
    '--scheduler',
    dest='scheduler',
    default='threads',
    help="""
        Parallelization method to use if parallelization is enabled.
    """,
)


def read_in_list(inputfile):
    """Read a text file into a list holding one entry per line.

    Only trailing newline characters are removed from each line; other
    trailing whitespace such as spaces or tabs is kept, and empty lines are
    returned as empty strings.

    Args:
        inputfile: Path of the text file to read.

    Returns:
        list[str]: The lines of the file, in file order.
    """
    with open(inputfile, 'r') as handle:
        return [line.rstrip('\n') for line in handle]

def read_in_vectors(inputfile, vocab_to_use):
    """Load labelled vectors from a tab-separated text file.

    Each non-blank line is expected to hold a label followed by the vector
    components, all separated by tabs.  Blank lines (for example a trailing
    empty line at the end of the file) are skipped; otherwise they would add
    an empty row and a bogus label, leaving the rows unequal in length.

    Args:
        inputfile: Path of the tab-separated vector file.
        vocab_to_use: Not used by this function; kept so that existing
            callers keep working.

    Returns:
        tuple: ``(vectors, label_to_index)`` where ``vectors`` is a NumPy
        array with one row per line read and ``label_to_index`` maps each
        label to the row index of its vector.  If a label occurs more than
        once, the index of its last occurrence wins.

    Raises:
        ValueError: If a vector component cannot be converted to ``float``.
    """
    vectors = []
    labels_to_define_to_num_dic = {}
    with open(inputfile, 'r') as infile:
        for row in infile:
            if not row.strip():
                continue  # blank line: nothing to parse
            # Split once: first field is the label, the rest is the vector.
            label, *values = row.rstrip('\n').split('\t')
            # The next free row index is the number of vectors stored so far.
            labels_to_define_to_num_dic[label] = len(vectors)
            vectors.append([float(value) for value in values])
    return np.array(vectors), labels_to_define_to_num_dic

def initialize_weights(vocab_to_use, word_to_be_defined):
    """Build the starting weight vector for one word to be defined.

    Every vocabulary entry starts with the uniform weight
    ``1 / len(vocab_to_use)``.  An entry's weight is then forced to 0.0 when

    * its second character is numeric (the original code comments describe
      such labels as sentences), or
    * it equals ``word_to_be_defined``, so that the word cannot be used in
      its own definition.

    The remaining weights are not re-normalised after entries are zeroed.

    Args:
        vocab_to_use: Sequence of vocabulary labels (strings).
        word_to_be_defined: Label of the word whose definition is sought.

    Returns:
        numpy.ndarray: Float array with one weight per vocabulary entry, in
        the order of ``vocab_to_use``.  Empty if the vocabulary is empty.
    """
    vocab_size = len(vocab_to_use)
    if vocab_size == 0:
        # Avoid dividing by zero; there is simply nothing to weight.
        return np.array([], dtype=float)

    weights = np.full(vocab_size, 1.0 / vocab_size)
    for i, label in enumerate(vocab_to_use):
        # Guard the index: one-character labels have no second character.
        looks_like_sentence = len(label) > 1 and label[1].isnumeric()
        if looks_like_sentence or label == word_to_be_defined:
            weights[i] = 0.0
    return weights

def calc_gradients2(word_to_be_defined_num, vector_array, weights, eta):
    """Alternate formulation of the per-word gradient computation.

    The prediction is the weighted sum of the rows of ``vector_array``; the
    error against the target row, scaled by ``eta``, is then multiplied with
    every row and summed to give one value per row.

    Args:
        word_to_be_defined_num: Row index of the target vector.
        vector_array: 2-D array indexed as ``[word, dimension]``.
        weights: One weight per row of ``vector_array``.
        eta: Learning-rate factor applied to the error.

    Returns:
        numpy.ndarray: One gradient value per row of ``vector_array``.
    """
    # Transposed view is [dimension, word]; each column is scaled by its weight
    # and the columns are added up to give the predicted vector.
    predicted = np.sum(vector_array.T * weights, axis=1)
    error = vector_array[word_to_be_defined_num] - predicted
    scaled_error = eta * error
    # Collapse each word's per-dimension contributions into a single value.
    return np.sum(scaled_error * vector_array, axis=1)

def calc_gradients(word_to_be_defined_num, vector_array, weights, eta):
    """Compute one gradient value per vocabulary row for a single update step.

    The prediction is the weighted sum of the rows of ``vector_array``.  The
    difference between the target row and that prediction is scaled by
    ``eta``, multiplied element-wise with every row, and summed per row.

    Args:
        word_to_be_defined_num: Row index of the target vector.
        vector_array: 2-D array indexed as ``[word, dimension]``.
        weights: 1-D NumPy array with one weight per row of ``vector_array``.
        eta: Learning-rate factor applied to the error.

    Returns:
        numpy.ndarray: One gradient value per row of ``vector_array``.
    """
    # Turn the weights into a column so they broadcast across each row.
    column_weights = weights[:, np.newaxis]
    predicted = np.sum(column_weights * vector_array, axis=0)

    residual = vector_array[word_to_be_defined_num] - predicted
    step = eta * residual

    # Broadcast the scaled residual over all rows, then sum per row.
    return np.sum(step[np.newaxis, :] * vector_array, axis=1)

def minimize(word_to_be_defined_num, vector_array, weights, Nupdates = 5000, eta = .001, tau = 100):
    """Iteratively adjust the weights, pruning small ones to exactly zero.

    On update ``i`` every non-zero weight is moved by its gradient, after
    which any of those weights lying below ``i / (tau * Nupdates)`` is set
    to 0.0.  A weight that is zero is never updated again, so it stays
    zero; this also keeps the word being defined (initialised to zero) out
    of its own definition.

    A larger ``tau`` makes the cutoff grow more slowly, so it is harder for a
    weight to be zeroed; more updates give more opportunities for weights to
    be zeroed.

    Args:
        word_to_be_defined_num: Row index of the target vector.
        vector_array: 2-D array indexed as ``[word, dimension]``.
        weights: 1-D NumPy array of starting weights.  It is modified in
            place.
        Nupdates: Number of update steps.
        eta: Learning-rate factor passed to ``calc_gradients``.
        tau: Controls how slowly the pruning cutoff grows.

    Returns:
        numpy.ndarray: The same ``weights`` array, after all updates.
    """
    cutoff_per_step = 1.0 / (tau * Nupdates)
    for step in range(Nupdates):
        gradient = calc_gradients(word_to_be_defined_num, vector_array, weights, eta)

        # Only weights that are still non-zero take part in this step.
        active = weights != 0.0
        proposed = weights + gradient
        weights[active] = proposed[active]

        # Prune: the comparison uses the freshly updated weights.
        threshold = cutoff_per_step * step
        weights[active & (weights < threshold)] = 0.0
    return weights

def combine_words_and_weights(word_to_be_defined, weights, vocab_to_use):
    """Pair each vocabulary label with its weight, largest weight first.

    The list is sorted once after it has been built.  The sort is stable, so
    entries with equal weights keep their vocabulary order.

    Args:
        word_to_be_defined: Not used by this function; kept so that existing
            callers keep working.
        weights: Sequence of weights; its length determines how many pairs
            are produced.
        vocab_to_use: Sequence of labels, indexed in parallel with
            ``weights``.

    Returns:
        list[list]: ``[label, weight]`` pairs in descending order of weight.

    Raises:
        IndexError: If ``vocab_to_use`` is shorter than ``weights``.
    """
    components = [[vocab_to_use[i], weight] for i, weight in enumerate(weights)]
    components.sort(key=lambda pair: pair[1], reverse=True)
    return components

def calculate_distance(word_to_be_defined_num, components, vector_array,labels_to_define_to_num_dic):
    """Return the Euclidean distance between a vector and its reconstruction.

    The reconstruction is the sum, over all ``components``, of each
    component's weight times the vector stored for its label.

    Args:
        word_to_be_defined_num: Row index of the target vector.
        components: Sequence of entries whose item 0 is a label and item 1
            its weight.
        vector_array: 2-D array indexed as ``[word, dimension]``.
        labels_to_define_to_num_dic: Maps a label to its row index in
            ``vector_array``.

    Returns:
        The norm of ``target - reconstruction`` as computed by
        ``numpy.linalg.norm``.

    Raises:
        KeyError: If a component label is missing from
            ``labels_to_define_to_num_dic``.
    """
    target_vector = vector_array[word_to_be_defined_num]
    reconstruction = np.array([0] * len(target_vector))
    for entry in components:
        label, weight = entry[0], entry[1]
        row = labels_to_define_to_num_dic[label]
        # Rebind instead of "+=": the accumulator starts out as integers.
        reconstruction = reconstruction + weight * vector_array[row]
    return np.linalg.norm(target_vector - reconstruction)

def save_to_disk(components_dic, distance_dic, outfilename):
    """Write the definitions to a tab-separated text file.

    One line is written per defined word, in sorted order of the words.  A
    line holds the word, its distance, and then one ``"label weight"`` field
    for every component whose weight is non-zero, all separated by tabs.

    The file is opened in a ``with`` block so that it is closed even if an
    error occurs while writing.

    Args:
        components_dic: Maps each defined word to a sequence of entries
            whose item 0 is a label and item 1 its weight.
        distance_dic: Maps each defined word to its distance value.
        outfilename: Path of the file to create or overwrite.

    Raises:
        KeyError: If a word in ``components_dic`` has no entry in
            ``distance_dic``.
    """
    with open(outfilename, 'w') as outfile:
        for word in sorted(components_dic):
            fields = [word, str(distance_dic[word])]
            fields.extend(
                f"{entry[0]} {entry[1]}"
                for entry in components_dic[word]
                if entry[1] != 0  # zero-weight components are omitted
            )
            outfile.write('\t'.join(fields) + '\n')

def define_word(word_to_be_defined, vocab_to_use, vector_array, labels_to_define_to_num_dic):
    """Compute the weighted definition of a single word or sentence.

    Initialises the weights, optimises them with ``minimize`` (using that
    function's default settings), pairs them with the vocabulary, and
    measures how far the reconstruction is from the target vector.  Start
    and completion are logged, the latter with the elapsed time.

    Args:
        word_to_be_defined: Label of the vector to define.
        vocab_to_use: Sequence of vocabulary labels available as components.
        vector_array: 2-D array indexed as ``[word, dimension]``.
        labels_to_define_to_num_dic: Maps a label to its row index in
            ``vector_array``.

    Returns:
        dict: Keys ``"word"`` (the label), ``"components"`` (the
        ``[label, weight]`` pairs sorted by weight) and ``"distance"``.

    Raises:
        KeyError: If ``word_to_be_defined`` is not in
            ``labels_to_define_to_num_dic``.
    """
    logger.info(f"word definition started for: {word_to_be_defined}")
    started = time.time()

    starting_weights = initialize_weights(vocab_to_use, word_to_be_defined)
    target_row = labels_to_define_to_num_dic[word_to_be_defined]

    fitted_weights = minimize(target_row, vector_array, starting_weights)
    ranked_components = combine_words_and_weights(
        word_to_be_defined, fitted_weights, vocab_to_use
    )
    reconstruction_error = calculate_distance(
        target_row, ranked_components, vector_array, labels_to_define_to_num_dic
    )

    elapsed = time.time() - started
    logger.info(f"word definition completed for: {word_to_be_defined} ({elapsed:.4f}s elapsed)")
    return dict(
        word=word_to_be_defined,
        components=ranked_components,
        distance=reconstruction_error,
    )

def main(args):
    """Define every requested word/sentence and write the results to disk.

    Reads the vectors, the list of items to define and the vocabulary to use
    from fixed file names in the working directory, computes one definition
    per item (through dask when ``args.parallelized`` is set, otherwise one
    after another), saves all definitions to the output file and logs the
    total elapsed time.

    Args:
        args: Parsed command-line arguments providing ``parallelized``
            (bool) and ``scheduler`` (passed to dask's ``compute``).
    """
    start_time = time.time()

    # Notes from the original author:
    # - Each sentence vector requires roughly 3-4 minutes to unpack.
    # - Multiple sentences can be unpacked simultaneously by adding
    #   --parallelized to the command line.
    # - If parallelization is requested, the program will attempt to process
    #   all of the sentences listed in the "to_define" file.
    # - For small numbers of sentences (e.g., 4) parallelization helps, but
    #   for large numbers the overhead slows down processing.

    # Vectors for all of the sentences and words.
    inputfile = 'nyt_d200_w5_13592_1.txt'
    # The sentences or words to unpack.
    vocab_to_define_file = 'nyt_d200_w5_13592_to_define_one_sent.txt'
    # The words to use in the unpacking.
    vocabulary_to_use_file = 'nyt_d200_w5_13592_to_use.txt'
    # The file the results are written to.
    outfilename1 = 'nyt_d200_w5_13592_comp.txt'

    vocab_to_define = read_in_list(vocab_to_define_file)
    vocab_to_use = read_in_list(vocabulary_to_use_file)
    vector_array, labels_to_define_to_num_dic = read_in_vectors(inputfile, vocab_to_use)

    # Arguments shared by every define_word call.
    shared_args = (vocab_to_use, vector_array, labels_to_define_to_num_dic)

    if args.parallelized:
        # Build the lazy tasks first, then let dask evaluate them together.
        tasks = [delayed(define_word)(word, *shared_args) for word in vocab_to_define]
        results = compute(*tasks, scheduler=args.scheduler)
    else:
        results = [define_word(word, *shared_args) for word in vocab_to_define]

    # Index the results by word for saving.
    components_dic = {result['word']: result['components'] for result in results}
    distance_dic = {result['word']: result['distance'] for result in results}

    save_to_disk(components_dic, distance_dic, outfilename1)
    end_time = time.time()
    logger.info(f"Done. total elapsed time was {end_time - start_time:g} seconds")

if __name__ == "__main__":
    args = parser.parse_args()
    main(args)